## Simulation study for the Bayesian sensitivity analysis (BSA) in bpCausal.
## The script:
##   - fits a naive model and a BSA model to simulated data;
##   - re-fits BSA with the sensitivity parameters beta_u / lambda_d fixed at
##     posterior deciles;
##   - writes figures 5-8 to plot/.
## All paths are relative to the working directory (code/, plot/).

## NOTE(review): clears the whole global environment -- running this file via
## source() in an interactive session wipes the caller's workspace.
rm(list = ls())
## fixed seed for the simulated data (and, presumably, the MCMC runs)
set.seed(12345)
library(bpCausal)
## NOTE(review): ggplot2 is not used in this part of the script; presumably
## needed by the sourced helpers -- verify before removing.
library(ggplot2)
## library(panelView)

## change wd (run from the replication root instead of hard-coding a path)
##setwd("/Users/lichengl/Desktop/BSA_replication_code/")




## project helpers: simulate() and, presumably, effSummary() and plotting
## utilities are defined in these files
source("code/sim.R")
source("code/plot_function.R")
source("code/summary_function.R")

## ---- simulation design ---------------------------------------------------- ##
## Trailing comments list the alternative values used in other settings.
## Argument meanings are those of simulate() in code/sim.R (presumably: N units,
## TT periods, Ntr treated units, T0 pre-treatment periods -- confirm there).
N <- 50   ## 50, 100
TT <- 50  ## 30, 50, 90
Ntr <- N * 0.2 ## 10, 20
T0 <- TT - 10  ## 20, 40, 80
p <- 5
beta <- c(1, 3, 0, 0, 0)
## beta <- c(1, 3)
r <- 2
twoway <- 1
c12 <- 0.25
c22 <- 0.25
c32 <- 0.25
c42 <- 0.25
c52 <- 0.25


Xname <- c("X1", "X2", "X3", "X4", "X5")
Xname2 <- c("X1", "X2", "U") ## oracle case where U is known


## True ATT is 0 (sharp null). Per the original author: one unobserved
## confounder; the exact setting does not matter for this exercise.
## `r = r` (not a literal) so editing `r` above actually reaches simulate().
datas <- simulate(N = N, TT = TT, Ntr = Ntr, T0 = T0, p = p,
                  beta = beta, r = r, twoway = twoway, ATT = 0,
                  c12 = c12, c22 = c22, c32 = c32,
                  c42 = c42, c52 = c52, coeff = 0.6, betau = 1)

data <- datas$data

## panel identifiers and variable names shared by every bpCausal() fit below
index <- c("id", "time")
Yname <- "Y"
Dname <- "D"



## ---- naive model: no sensitivity analysis (sense = 0) ---------------------- ##
## NOTE(review): all fits use c32 = 0.5 while the simulation above used
## c32 = 0.25 -- confirm this mismatch is intended.
fit1 <- bpCausal(
  data = data, index = index,
  Yname = Yname, Dname = Dname, Xname = Xname,
  Zname = NULL, Aname = NULL,
  re = "both", ar1 = FALSE, r = 10,
  niter = 15000, burn = 5000,
  xlasso = 1, zlasso = 1, alasso = 1, flasso = 1,
  a1 = 0.001, a2 = 0.001,
  b1 = 0.001, b2 = 0.001,
  c1 = 0.001, c2 = 0.001,
  p1 = 0.001, p2 = 0.001,
  sense = 0,
  c12 = 0.25, c22 = 0.25, c32 = 0.5, c42 = 0.25, c52 = 0.25
)


## ---- BSA model: same settings but sense = 1 and p1 = 1.001 ---------------- ##
fit12 <- bpCausal(
  data = data, index = index,
  Yname = Yname, Dname = Dname, Xname = Xname,
  Zname = NULL, Aname = NULL,
  re = "both", ar1 = FALSE, r = 10,
  niter = 15000, burn = 5000,
  xlasso = 1, zlasso = 1, alasso = 1, flasso = 1,
  a1 = 0.001, a2 = 0.001,
  b1 = 0.001, b2 = 0.001,
  c1 = 0.001, c2 = 0.001,
  p1 = 1.001, p2 = 0.001,
  sense = 1,
  c12 = 0.25, c22 = 0.25, c32 = 0.5, c42 = 0.25, c52 = 0.25
)


## Share of distinct values among the kept draws. 10000 = niter - burn;
## presumably this approximates a Metropolis acceptance rate -- confirm.
length(unique(c(fit12$sigma2))) / 10000
## [1] 0.8049
length(unique(c(fit12$betau))) / 10000
## [1] 0.459




## condition on betau or ld and then re-run BSA

## Posterior deciles (10%, ..., 90%) of the two sensitivity parameters from the
## unconditional BSA fit; each decile is used as a fixed value in a refit.
decile.probs <- seq(from = 0.1, to = 0.9, by = 0.1)
## deciles of beta_u
betau.seq <- quantile(fit12$betau, probs = decile.probs)
## deciles of lambda_d
ld.seq <- quantile(fit12$ld, probs = decile.probs)

n.betau <- length(betau.seq)

## distribution of lambda_d given beta_u: one element of draws per decile
## (preallocated rather than grown inside the loop)
sim.ld <- vector("list", n.betau)

## ATT conditional on each beta_u decile; one row per decile holding
## the fixed beta_u, posterior mean ATT, and 2.5% / 97.5% quantiles
betau_catt <- matrix(NA, n.betau, 4)

for (i in seq_len(n.betau)) {

  ## BSA refit with betau0 = betau.seq[i] (presumably fixes beta_u there)
  fit13 <- bpCausal(data = data, index = index,
                    Yname = Yname, Dname = Dname, Xname = Xname,
                    Zname = NULL, Aname = NULL, re = "both",
                    ar1 = FALSE, r = 10, niter = 15000, burn = 5000,
                    xlasso = 1, zlasso = 1, alasso = 1, flasso = 1,
                    a1 = 0.001, a2 = 0.001, b1 = 0.001, b2 = 0.001,
                    c1 = 0.001, c2 = 0.001, p1 = 1.001, p2 = 0.001,
                    sense = 1, c12 = 0.25, c22 = 0.25,
                    c32 = 0.5, c42 = 0.25, c52 = 0.25,
                    betau0 = betau.seq[i])

  avg1 <- effSummary(fit13, datt = FALSE)
  betau_catt[i, ] <- c(betau.seq[i], mean(avg1$est.avg),
                       quantile(avg1$est.avg, c(0.025, 0.975)))

  sim.ld[[i]] <- fit13$ld
}

colnames(betau_catt) <- c("betau", "att", "ci_l", "ci_u")




## Distribution of beta_u given lambda_d: one element of draws per lambda_d
## decile in ld.seq (preallocated rather than grown inside the loop).
n.ld <- length(ld.seq)
sim.betau <- vector("list", n.ld)

## ATT conditional on each lambda_d decile; one row per decile holding
## the fixed lambda_d, posterior mean ATT, and 2.5% / 97.5% quantiles
ld_catt <- matrix(NA, n.ld, 4)


for (i in seq_len(n.ld)) {

  ## BSA refit with ld0 = ld.seq[i] (presumably fixes lambda_d there)
  fit14 <- bpCausal(data = data, index = index,
                    Yname = Yname, Dname = Dname, Xname = Xname,
                    Zname = NULL, Aname = NULL, re = "both",
                    ar1 = FALSE, r = 10, niter = 15000, burn = 5000,
                    xlasso = 1, zlasso = 1, alasso = 1, flasso = 1,
                    a1 = 0.001, a2 = 0.001, b1 = 0.001, b2 = 0.001,
                    c1 = 0.001, c2 = 0.001, p1 = 1.001, p2 = 0.001,
                    sense = 1, c12 = 0.25, c22 = 0.25,
                    c32 = 0.5, c42 = 0.25, c52 = 0.25,
                    ld0 = ld.seq[i])

  avg1 <- effSummary(fit14, datt = FALSE)
  ld_catt[i, ] <- c(ld.seq[i], mean(avg1$est.avg),
                    quantile(avg1$est.avg, c(0.025, 0.975)))

  sim.betau[[i]] <- fit14$betau
}

colnames(ld_catt) <- c("ld", "att", "ci_l", "ci_u")


## ---- two additional BSA refits: beta_u = 0 (fit15) and lambda_d = 0 (fit16) ##
## Identical to the BSA model except for the single fixed parameter.

fit15 <- bpCausal(
  data = data, index = index,
  Yname = Yname, Dname = Dname, Xname = Xname,
  Zname = NULL, Aname = NULL,
  re = "both", ar1 = FALSE, r = 10,
  niter = 15000, burn = 5000,
  xlasso = 1, zlasso = 1, alasso = 1, flasso = 1,
  a1 = 0.001, a2 = 0.001,
  b1 = 0.001, b2 = 0.001,
  c1 = 0.001, c2 = 0.001,
  p1 = 1.001, p2 = 0.001,
  sense = 1,
  c12 = 0.25, c22 = 0.25, c32 = 0.5, c42 = 0.25, c52 = 0.25,
  betau0 = 0
)

fit16 <- bpCausal(
  data = data, index = index,
  Yname = Yname, Dname = Dname, Xname = Xname,
  Zname = NULL, Aname = NULL,
  re = "both", ar1 = FALSE, r = 10,
  niter = 15000, burn = 5000,
  xlasso = 1, zlasso = 1, alasso = 1, flasso = 1,
  a1 = 0.001, a2 = 0.001,
  b1 = 0.001, b2 = 0.001,
  c1 = 0.001, c2 = 0.001,
  p1 = 1.001, p2 = 0.001,
  sense = 1,
  c12 = 0.25, c22 = 0.25, c32 = 0.5, c42 = 0.25, c52 = 0.25,
  ld0 = 0
)







## save results
#save(datas, fit1, fit12, fit15, fit16, betau_catt, ld_catt, sim.ld, sim.betau,
#    file = "results/sim_sense_condition.RData")


#load("results/sim_sense_condition.RData")




## -------------------------------------------------------------------------- ##
## 1. plot of conditional vs marginal distributions of sensitivity parameters ##
## -------------------------------------------------------------------------- ##


## 1.1 posterior dependency: ld vs betau
## Box k (k = 1..n) holds the lambda_d draws from the refit with beta_u fixed at
## its k-th decile. Box n + 1 is an Inf-filled gap. Box n + 2 holds the
## marginal lambda_d draws of the unconditional BSA fit. Group sizes follow the
## actual number of draws instead of a hard-coded 10000, so changing
## niter/burn or the number of deciles keeps data and groups aligned.

ld.marginal <- c(fit12$ld)
n.cond.ld <- length(sim.ld)

## combined with original; note: we add a gap between the deciles and marginal
v.sim.ld <- c(unlist(sim.ld), rep(Inf, length(ld.marginal)), ld.marginal)
index.ld <- rep(seq_len(n.cond.ld + 2),
                times = c(lengths(sim.ld), length(ld.marginal),
                          length(ld.marginal)))

## pdf("plot/sim0_ld_dist.pdf", width = 10, height = 7)
pdf("plot/figure5_left.pdf", width = 10, height = 7)
boxplot(v.sim.ld ~ index.ld, ylim = c(-3, 2), main = "",
        ylab = expression(lambda[d]), xlab = expression(beta[u]), xaxt = "n")
## label the conditional boxes with their beta_u value; the marginal box is
## left unlabeled
axis(1, at = c(seq_len(n.cond.ld), n.cond.ld + 2),
     labels = c(round(betau_catt[, 1], digits = 3), ""), las = 1)

## interquartile range of the marginal lambda_d draws
abline(h = quantile(ld.marginal, 0.25), col = "red", lty = 2)
abline(h = quantile(ld.marginal, 0.75), col = "red", lty = 2)

dev.off()



## 1.2 posterior dependency: betau vs ld
## Box k (k = 1..n) holds the beta_u draws from the refit with lambda_d fixed at
## its k-th decile. Box n + 1 is an Inf-filled gap. Box n + 2 holds the
## marginal beta_u draws of the unconditional BSA fit. Group sizes follow the
## actual number of draws instead of a hard-coded 10000.

betau.marginal <- c(fit12$betau)
n.cond.betau <- length(sim.betau)

## combined with original; note: we add a gap between the deciles and marginal
v.sim.betau <- c(unlist(sim.betau), rep(Inf, length(betau.marginal)),
                 betau.marginal)
index.betau <- rep(seq_len(n.cond.betau + 2),
                   times = c(lengths(sim.betau), length(betau.marginal),
                             length(betau.marginal)))

## device opened only after the data are assembled, so a preparation error
## does not leave an open pdf device behind
##pdf("plot/sim0_betau_dist.pdf", width = 10, height = 7)
pdf("plot/figure5_right.pdf", width = 10, height = 7)
boxplot(v.sim.betau ~ index.betau, ylim = c(-4, 4), main = "",
        xlab = expression(lambda[d]), ylab = expression(beta[u]), xaxt = "n")
## label the conditional boxes with their lambda_d value; the marginal box is
## left unlabeled
axis(1, at = c(seq_len(n.cond.betau), n.cond.betau + 2),
     labels = c(round(ld_catt[, 1], digits = 3), ""), las = 1)

## interquartile range of the marginal beta_u draws
abline(h = quantile(betau.marginal, 0.25), col = "red", lty = 2)
abline(h = quantile(betau.marginal, 0.75), col = "red", lty = 2)

dev.off()







## -------------------------------------------------------------------------- ##
##                       2. plot of treatment effects                         ##
## -------------------------------------------------------------------------- ##


## 2.0 data preparation

## unconditional posterior draws of the sensitivity parameters (BSA fit)
ld <- fit12$ld
betau <- fit12$betau

## kernel density estimates evaluated on a 1000-point grid
dld <- density(ld, n = 1000)
dbu <- density(betau, n = 1000)

dl_x <- dld$x
dl_y <- dld$y
dbu_x <- dbu$x
dbu_y <- dbu$y

## true ATT used in simulate() above (ATT = 0): null case
att0 <- 0





## 2.1 conditional treatment effects by deciles of beta_u
## Background: histogram of the marginal beta_u draws, rescaled so the
## tallest bar has height 1. Foreground: posterior mean ATT and 95% interval
## for each beta_u decile. ATT values are shifted up by y.adj to share the
## [0, 9] y-range with the histogram; the left-axis labels subtract y.adj again.

y.adj <- 4
ylim <- c(0, 9)
y.seq <- ylim[1]:ylim[2]

## plot = FALSE: only the bins are needed here. The default (plot = TRUE)
## would draw on whatever device is active (Rplots.pdf under Rscript) before
## pdf() below is opened.
ff <- hist(betau, breaks = 100, plot = FALSE)
ff$counts <- ff$counts / max(ff$counts)



## plot

## pdf("plot/sim0_att_betau.pdf", width = 8, height = 7)
pdf("plot/figure6_left.pdf", width = 10, height = 7)
plot(ff, col = "grey", ylim = ylim, xaxt = "n", yaxt = "n",
     xlab = expression(beta[u]), ylab = "ATT", main = "",
     cex.lab = 1.5, cex.axis = 1.5, cex.main = 1.5, cex.sub = 1.5)

## axis creation: x ticks at the deciles, y labels undo the y.adj shift
axis(1, at = betau_catt[, 1], labels = round(betau_catt[, 1], digits = 3), las = 1)
axis(2, at = y.seq, labels = y.seq - y.adj, las = 2, cex.axis = 1.5)

points(betau_catt[, 1], betau_catt[, 2] + y.adj, pch = 19)

## 95% intervals drawn as arrows with flat (90-degree) heads at both ends
arrows(betau_catt[, 1],
       betau_catt[, 3] + y.adj,
       betau_catt[, 1],
       betau_catt[, 4] + y.adj,
       length = 0.1, angle = 90, code = 3)

## reference line at ATT = 0
abline(h = 0 + y.adj, lty = 3, lwd = 2, col = "red")

dev.off()







## 2.2 conditional treatment effects by deciles of lambda_d
## Background: histogram of the marginal lambda_d draws, rescaled so the
## tallest bar has height 1. Foreground: posterior mean ATT and 95% interval
## for each lambda_d decile, shifted up by y.adj; the left-axis labels
## subtract y.adj again.


y.adj <- 4
ylim <- c(0, 9)
## NOTE(review): xlim is not passed to plot() below -- kept for compatibility
xlim <- c(-3, 2)
y.seq <- ylim[1]:ylim[2]

## plot = FALSE: only the bins are needed; drawing here would go to the
## active device (Rplots.pdf under Rscript) before pdf() is opened
ff <- hist(ld, breaks = 100, plot = FALSE)
ff$counts <- ff$counts / max(ff$counts)


## plot

##pdf("plot/sim0_att_lambda_d.pdf", width = 8, height = 7)
pdf("plot/figure6_right.pdf", width = 10, height = 7)
plot(ff, col = "grey", ylim = ylim, xaxt = "n", yaxt = "n",
     xlab = expression(lambda[d]), ylab = "ATT", main = "",
     cex.lab = 1.5, cex.axis = 1.5, cex.main = 1.5, cex.sub = 1.5)

## axis creation: x ticks at the deciles, y labels undo the y.adj shift
axis(1, at = ld_catt[, 1], labels = round(ld_catt[, 1], digits = 3), las = 1)
axis(2, at = y.seq, labels = y.seq - y.adj, las = 2, cex.axis = 1.5)

points(ld_catt[, 1], ld_catt[, 2] + y.adj, pch = 19)

## 95% intervals drawn as arrows with flat (90-degree) heads at both ends
arrows(ld_catt[, 1],
       ld_catt[, 3] + y.adj,
       ld_catt[, 1],
       ld_catt[, 4] + y.adj,
       length = 0.1, angle = 90, code = 3)

## reference line at ATT = 0
abline(h = 0 + y.adj, lty = 3, lwd = 2, col = "red")

dev.off()



## 2.2.1 contour plot for zero-effect
## The adjusted ATT on a grid of (beta_u, lambda_d) is computed as
## naive ATT - beta_u * lambda_d; the red contour marks where it equals zero.
avg1 <- effSummary(fit1)
att_naive <- mean(avg1$est.avg)

## 100-point grids spanning the central 95% of each marginal posterior
## beta_u
bu_limit <- quantile(fit12$betau, probs = c(0.025, 0.975))
bu.seq <- seq(bu_limit[1], bu_limit[2], length.out = 100)
## lambda_d
lda_limit <- quantile(fit12$ld, probs = c(0.025, 0.975))
lda.seq <- seq(lda_limit[1], lda_limit[2], length.out = 100)

## att_adj[i, j] is the adjusted ATT at (bu.seq[i], lda.seq[j])
att_adj <- outer(bu.seq, lda.seq, function(bu, lda) {
  att_naive - bu * lda
})


##pdf("plot/sim0_att_contour.pdf", width = 8, height = 7)
pdf("plot/figure7.pdf", width = 10, height = 7)
contour(bu.seq, lda.seq, att_adj,
        nlevels = 30,
        xlab = expression(beta[u]), ylab = expression(lambda[d]),
        cex.axis = 1.5, cex.lab = 1.5,
        col = "black", lwd = 1.5, labcex = 1.5,
        main = "")

## overlay the zero-effect level in red
contour(bu.seq, lda.seq, att_adj,
        levels = 0, add = TRUE,
        labcex = 1.5, col = "red", lwd = 2)

## mark the origin (no confounding adjustment) with the naive ATT
att_output <- round(att_naive, 3)
points(0, 0, pch = 19, cex = 1.2, col = "black")
text(0, 0,
     labels = paste0("Unadjusted \n (", att_output, ")"),
     pos = 3, cex = 1.2, col = "black")

dev.off()




## 2.3 compare naive, BSA, and BSA with lambda_d = 0 or beta_u = 0

## naive:
avg1 <- effSummary(fit1)

## BSA
avg2 <- effSummary(fit12)

## ld = 0
avg3 <- effSummary(fit16)

## betau = 0
avg4 <- effSummary(fit15)

## Posterior draws of the average ATT in plotting order (x = 1..4). Point
## estimates and 95% bounds come from the same list, so each interval is
## guaranteed to belong to its own estimate. Previously the upper bound for
## the beta_u = 0 column was taken from avg3 (the lambda_d = 0 fit).
att.draws <- list(avg1$est.avg, avg2$est.avg, avg3$est.avg, avg4$est.avg)
att.mean <- vapply(att.draws, mean, numeric(1))
att.lo <- vapply(att.draws, quantile, numeric(1), probs = 0.025, names = FALSE)
att.hi <- vapply(att.draws, quantile, numeric(1), probs = 0.975, names = FALSE)

## x = 0 and x = 5 are empty (NA) slots that pad the horizontal range
x.pos <- c(0, 1, 2, 3, 4, 5)


## pdf("plot/sim0_att_compare.pdf", width = 8, height = 7)
pdf("plot/figure8_right.pdf", width = 10, height = 7)
plot(x.pos,
     c(NA, att.mean, NA),
     ylim = c(-2, 3),
     pch = 19, xlab = "", ylab = "ATT",
     main = "", xaxt = "n",
     cex.lab = 1.5, cex.axis = 1.5, cex.main = 1.5, cex.sub = 1.5)

## 95% intervals drawn as arrows with flat (90-degree) heads at both ends
arrows(x.pos, c(NA, att.lo, NA),
       x.pos, c(NA, att.hi, NA),
       length = 0.1, angle = 90, code = 3)

## reference line at ATT = 0
abline(h = 0, lty = 3, lwd = 2, col = "red")

axis(1, at = c(1, 2, 3, 4),
     labels = c("w/o BSA", "w/ BSA",
                expression(lambda[d] == 0),
                expression(beta[u] == 0)),
     col.axis = "black", cex.axis = 1.5, cex.lab = 1.5)

dev.off()





## 2.4 dynamic ATT w/ and w/o BSA
## Period-by-period effects from 10 periods before to 6 periods after
## treatment. BSA fit is drawn in red, naive fit in blue.
## NOTE(review): columns 5, 6, 7 of est.eff are presumably the point estimate
## and the 95% interval bounds -- confirm against effSummary().
## The former `TT <- dim(...)` line was removed: it clobbered the simulation
## parameter TT and was never used. A dead `pos` assignment was also removed.

## row names of est.eff are relative periods
period <- as.numeric(rownames(avg2$est.eff))
pos <- which(period >= -10 & period <= 6)
period.plot <- period[pos]

## Select the same periods from both fits by row name, so the naive curve
## cannot be silently misaligned if its rows are ordered differently.
rows.plot <- rownames(avg2$est.eff)[pos]
eff.bsa <- avg2$est.eff[rows.plot, , drop = FALSE]
eff.naive <- avg1$est.eff[rows.plot, , drop = FALSE]


## r plot
## pdf("plot/sim0_datt.pdf", width = 10, height = 7)
pdf("plot/figure8_left.pdf", width = 10, height = 7)
plot(period.plot, eff.bsa[, 5], col = "red",
     type = "l", lwd = 5, ylim = c(-6, 15),
     xlab = "Time relative to Treatment", ylab = "ATT", main = "",
     cex.lab = 1.5, cex.axis = 1.5, cex.main = 1.5, cex.sub = 1.5)

lines(period.plot, eff.bsa[, 6], col = "red", lwd = 3, lty = 3)
lines(period.plot, eff.bsa[, 7], col = "red", lwd = 3, lty = 3)

lines(period.plot, eff.naive[, 5], col = "blue", lwd = 5)
lines(period.plot, eff.naive[, 6], col = "blue", lwd = 3, lty = 3)
lines(period.plot, eff.naive[, 7], col = "blue", lwd = 3, lty = 3)

## reference lines: treatment onset and zero effect
abline(v = 0, lty = 3, lwd = 2)
abline(h = 0, lty = 3, lwd = 2)

legend(-10, 15, legend = c("W/ BSA", "W/o BSA"),
       col = c("red", "blue"), lty = 1, lwd = 2, pt.cex = 1.5)

dev.off()







